package com.chenzhiling.study.datasource.RemoteFile

import com.chenzhiling.study.datasource.FileSchema
import com.chenzhiling.study.util.FileUtil
import com.sshtools.net.SocketTransport
import com.sshtools.sftp.{SftpClient, SftpFile}
import com.sshtools.ssh.{PasswordAuthentication, SshAuthentication, SshConnector}
import com.sshtools.ssh2.Ssh2Client
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.{Partition, SparkContext, TaskContext}

import java.io.InputStream

/**
 * @Author: CHEN ZHI LING
 * @Date: 2021/8/26
 * @Description: Builds an RDD whose rows are read from files on a remote server.
 */
/**
 * Partition descriptor for a remote-file RDD.
 *
 * @param index position of this partition, exposed through Spark's `Partition.index`
 * @param path  path carried by the partition; presumably the remote location it
 *              covers -- it is not read anywhere in this file, confirm with the
 *              code that constructs these partitions
 */
case class RemoteFilePartition(
  index: Int,
  path: String
) extends Partition
/**
 * RDD of [[Row]]s built from the files found under `remotePath` on a server
 * reached over SFTP with password authentication.
 *
 * Each call to `compute` opens its own SSH/SFTP connection, reads the listed
 * files fully into memory as rows, and closes the connection before returning.
 *
 * @param sc         Spark context the RDD is attached to
 * @param ip         host name or address of the remote server
 * @param port       SSH port of the remote server
 * @param username   login user
 * @param password   login password
 * @param remotePath remote path whose files are turned into rows
 */
class RemoteFileRdd(sc: SparkContext,
                    ip: String,
                    port: Int,
                    username: String,
                    password: String,
                    remotePath: String) extends RDD[Row](sc, Nil) with Serializable {

  /**
   * Connects to the remote server, lists the files under `remotePath` and
   * converts them into rows.
   *
   * NOTE(review): `split` is not consulted, so every partition reads the whole of
   * `remotePath`. That is only correct if `FileSchema.getPartition` yields a
   * single partition -- confirm.
   *
   * @param split   partition being computed (currently unused)
   * @param context task context (currently unused)
   * @return iterator over the rows, already materialised in memory
   * @throws RuntimeException if password authentication does not complete
   */
  override def compute(split: Partition, context: TaskContext): Iterator[Row] = {
    // Create the SSH connector and the socket transport to the server.
    val connector: SshConnector = SshConnector.createInstance()
    val transport = new SocketTransport(ip, port)
    val ssh: Ssh2Client = connector.connect(transport, username)
    try {
      // Password authentication.
      val authentication = new PasswordAuthentication()
      authentication.setPassword(password)
      if (ssh.authenticate(authentication) != SshAuthentication.COMPLETE) {
        throw new RuntimeException("连接服务器异常")
      }
      val sftpClient = new SftpClient(ssh)
      try {
        // List the remote files and convert them to rows. Array.flatMap is eager, so
        // all remote reads happen here, before the connection is closed below.
        val files: Array[SftpFile] = FileUtil.listRemoteFiles(remotePath, client = sftpClient)
        val rows: Array[Row] = files.flatMap((file: SftpFile) => sftpFileToRow(file, sftpClient))
        rows.iterator
      } finally {
        // Actually close the SFTP client (the previous `isClosed` call only queried state).
        // Best effort: a failure while closing must not hide the result or the
        // original exception.
        scala.util.Try(sftpClient.quit())
      }
    } finally {
      // Always release the SSH connection, including on authentication or read failure.
      ssh.disconnect()
    }
  }

  /** Delegates partition planning for `remotePath` to `FileSchema.getPartition`. */
  override protected def getPartitions: Array[Partition] = {
    FileSchema.getPartition(remotePath)
  }


  /**
   * Converts one remote file into rows.
   *
   * The file content is opened as a stream, split into byte-array chunks by
   * `FileUtil.splitRemoteStream`, and each chunk is combined with the file
   * metadata through `FileSchema.filledSchema`.
   *
   * NOTE(review): the stream is not closed here; whether
   * `FileUtil.splitRemoteStream` closes it cannot be seen from this file.
   *
   * @param file       remote file entry
   * @param sftpClient open SFTP client used to read the file
   * @return iterator of rows, one per chunk produced by `splitRemoteStream`
   */
  private def sftpFileToRow(file: SftpFile, sftpClient: SftpClient): Iterator[Row] = {
    // Absolute path on the remote server, also used to open the content stream.
    val path: String = file.getAbsolutePath
    val stream: InputStream = sftpClient.getInputStream(path)
    val fileName: String = file.getFilename
    // File size: UnsignedInteger64 converted to Long.
    val size: Long = file.getAttributes.getSize.longValue()
    // File name suffix (extension).
    val suffix: String = FileUtil.getFileSuffix(fileName)
    // Split the stream into chunks and wrap each chunk in a row.
    FileUtil.splitRemoteStream(stream, size)
      .map((chunk: Array[Byte]) => FileSchema.filledSchema(fileName, path, suffix, size, chunk))
  }
}
